Skip to content

[JAX] Expert Parallelism: JAX primitives + VJPs#3036

Open
phu0ngng wants to merge 6 commits into
NVIDIA:mainfrom
phu0ngng:phuong/ep-3-jax
Open

[JAX] Expert Parallelism: JAX primitives + VJPs#3036
phu0ngng wants to merge 6 commits into
NVIDIA:mainfrom
phu0ngng:phuong/ep-3-jax

Conversation

@phu0ngng
Copy link
Copy Markdown
Collaborator

@phu0ngng phu0ngng commented May 22, 2026

Summary

Third PR in the TE Expert Parallelism (EP) series, built on top of #3034. Lands the JAX bindings: an XLA FFI layer over the nvte_ep_* C API, a Python wrapper with custom_vjp for autograd, mesh-aware sharding rules, a multi-process test suite, and an end-to-end MoE example. NCCL ncclEpDispatch/ncclEpCombine are exposed as XLA primitives and work with CUDA-graph capture.

Implementation

Public Python API (transformer_engine/jax/ep.py)

from transformer_engine.jax.ep import (
    EpHandle,        # opaque (id, handle_mem) pair from ep_prepare
    ep_bootstrap,    # one-shot per-process: init NCCL comm + nvte_ep_initialize
    ep_dispatch,     # custom_vjp-wrapped dispatch 
    ep_combine,      # custom_vjp-wrapped combine

ep_dispatch / ep_combine are jax.custom_vjp functions: forward is the FFI primitive, backward calls the matching nvte_ep_*_bwd FFI primitive directly (no ep_prepare in the bwd — routing state is already cached in handle.mem). Note that ep_dispatch also calls ep_prepare in the forward path, which all-gathers and preprocesses routing maps.

XLA FFI bindings (transformer_engine/jax/csrc/extensions/ep.cpp)

Five XLA_FFI_DEFINE_HANDLER_SYMBOL entries — EpPrepareHandler, EpDispatchHandler, EpCombineHandler, EpDispatchBwdHandler, EpCombineBwdHandler — each calling the corresponding nvte_ep_* C entry point. All marked FFI_CudaGraph_Traits so they capture cleanly. handle_id is a static FFI attribute baked at jit trace time.

Primitives + Python layer (transformer_engine/jax/cpp_extensions/ep.py, +951 lines)

Standard TE primitive plumbing: abstract_eval (shape/dtype inference), lowering, impl, outer_primitive registration, and partitioning rules so the EP collective is treated as a single sharded op by XLA (no spurious resharding around it).

Sharding (transformer_engine/jax/sharding.py, +12 lines)

Adds the EP mesh axis to the global mesh resource set so downstream sharding rules can reference it.

Build wiring (build_tools/jax.py, +41 lines)

Threads NCCL EP linkage through the JAX transformer_engine_jax extension. No new top-level build flags; rides on the parent PR's NVTE_BUILD_WITH_NCCL_EP.

Tests & example

  • tests/jax/test_multi_process_ep.py (+690 lines): 13 tests covering bootstrap, ep_prepare shape/handle contracts, primitive-level dispatch/combine identity (uniform + skewed routing), custom_vjp fwd+bwd correctness, and HLO inspection (must not insert XLA collectives outside the EP FFI).
  • tests/jax/multi_process_launch_ep.sh: 4-rank launcher; sets XLA_FLAGS to keep XLA command-buffer capture off for the EP FFI sequence (NCCL EP graph-destroy interaction).
  • examples/jax/ep/ep_moe.py (+394 lines) + run_test_ep.sh: end-to-end MoE with EP, dp=ep=2 mesh, includes a ref-comparison --check that verifies fwd+bwd vs a single-process reference.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 22, 2026

Greptile Summary

This PR lands the JAX Expert Parallelism (EP) bindings: XLA FFI handlers wrapping the nvte_ep_* C API, jax.custom_vjp-wrapped ep_dispatch/ep_combine with mesh-aware sharding rules, build wiring for the NCCL EP submodule, a multi-process test suite, and an end-to-end MoE example.

  • transformer_engine/jax/cpp_extensions/ep.py (+955 lines): five new primitives (EpPrepare, EpDispatch, EpCombine, EpDispatchBwd, EpCombineBwd) each with abstract_eval, lowering, impl, partition, and shardy_sharding_rule.
  • transformer_engine/jax/csrc/extensions/ep.cpp (+539 lines): five XLA_FFI_DEFINE_HANDLER_SYMBOL entries, NCCL communicator lifetime management via EpInstanceState.
  • transformer_engine/jax/ep.py (+303 lines): public ep_bootstrap, ep_dispatch and ep_combine with custom_vjp.

Confidence Score: 4/5

Safe to merge with one fix: the dispatch-backward partition function declares an output sharding with the wrong rank for grad_topk_weights, causing a JAX compile-time error in any multi-device training run that backpropagates through ep_dispatch.

The dispatch-backward partition function specifies PartitionSpec(*resolved, None) for grad_topk_weights, producing a spec one rank wider than the tensor's actual shape. Under SPMD JIT with any mesh, JAX will reject the sharding at compile time. The bug is latent in all backward-training paths and easy to hit once EP is exercised in a real training loop. The rest of the primitive stack, the C++ FFI layer, and the custom_vjp math all look correct.

transformer_engine/jax/cpp_extensions/ep.py — specifically the EpDispatchBwdPrimitive.partition method (lines 763-766).

Important Files Changed

Filename Overview
transformer_engine/jax/cpp_extensions/ep.py New file (+955 lines): five EP primitives with abstract_eval, lowering, partition, and shardy_sharding_rule. EpDispatchBwdPrimitive.partition uses PartitionSpec(*resolved, None) for grad_topk_weights, producing a spec one rank wider than the tensor, causing a JAX error at SPMD compile time.
transformer_engine/jax/csrc/extensions/ep.cpp New file (+539 lines): five XLA FFI handlers plus NCCL comm lifetime management. topk_weights unconditionally wrapped as DType::kFloat32 without dtype validation (flagged in previous review). Otherwise structurally sound.
transformer_engine/jax/ep.py New file (+303 lines): public ep_bootstrap with input validation, ep_dispatch/ep_combine with custom_vjp. VJP math looks correct; sharding constraint re-pinning in backward is well-handled.
build_tools/jax.py Adds NCCL EP linkage with hard RuntimeError when submodule header is missing or an arch < 90 is in NVTE_CUDA_ARCHS. Inconsistency with setup.py graceful disable noted in previous review threads.
transformer_engine/jax/sharding.py Adds ep_resource field to MeshResource dataclass and ep_axis_size() helper. Clean, non-breaking addition.
tests/jax/test_multi_process_ep.py 690-line multi-process test suite. All tests require SM>=90 hardware, so the partition-spec bug in EpDispatchBwd would not be caught in CI without Hopper GPUs.

Sequence Diagram

sequenceDiagram
    participant PY as Python (ep.py)
    participant Prim as JAX Primitives (cpp_extensions/ep.py)
    participant FFI as XLA FFI (ep.cpp)
    participant NCCL as NCCL EP (nvte_ep_*)

    Note over PY: ep_bootstrap()
    PY->>FFI: SetEpBootstrapParams(uid, ep_size, ...)
    FFI->>NCCL: ncclCommInitRank + nvte_ep_initialize

    Note over PY: ep_dispatch() forward
    PY->>Prim: ep_prepare(topk_idx)
    Prim->>FFI: EpPrepareHandler
    FFI->>NCCL: nvte_ep_prepare
    NCCL-->>PY: token_counts, EpHandle

    PY->>Prim: ep_dispatch_fwd(handle, tokens, topk_weights)
    Prim->>FFI: EpDispatchHandler
    FFI->>NCCL: nvte_ep_dispatch
    NCCL-->>PY: recv_tokens, recv_topk_weights

    Note over PY: Expert FFN runs on recv_tokens

    PY->>Prim: ep_combine_fwd(handle, weighted_expert_out)
    Prim->>FFI: EpCombineHandler
    FFI->>NCCL: nvte_ep_combine
    NCCL-->>PY: combined output

    Note over PY: ep_dispatch() backward
    PY->>Prim: ep_dispatch_bwd(handle, g_recv_tokens, g_recv_topk_weights)
    Prim->>FFI: EpDispatchBwdHandler
    FFI->>NCCL: nvte_ep_dispatch_bwd
    NCCL-->>PY: grad_tokens, grad_topk_weights

    PY->>Prim: ep_combine_bwd(handle, g_result)
    Prim->>FFI: EpCombineBwdHandler
    FFI->>NCCL: nvte_ep_combine_bwd
    NCCL-->>PY: grad_expert_out
Loading

Reviews (5): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

Comment thread build_tools/jax.py
Comment thread build_tools/jax.py
Comment thread transformer_engine/jax/cpp_extensions/ep.py
}

private:
EpCommManager() = default;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we use stateful FFI calls we could tie to EP communicator to the lifetime of the jax computation rather than the process.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool to learn! I will update it.

Error_Type EpPrepareFFI(cudaStream_t stream, Buffer_Type topk_idx, Result_Type token_counts,
Result_Type handle_mem, Result_Type workspace, EpPrepareConfig config) {
auto topk_dims = topk_idx.dimensions();
NVTE_CHECK(topk_dims.size() >= 2,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can we return FFI InvalidArgument instead of a NVTE_CHECK for these inputs?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably a good idea. I suggest we make another follow-up MR to do so for all the FFIs.

@phu0ngng phu0ngng requested a review from tdophung May 22, 2026 15:51
@phu0ngng
Copy link
Copy Markdown
Collaborator Author

I would appreciate your help to review this PR @tdophung @jberchtold-nvidia!
Please focus on the changes in the JAX side, as the TE/Common ones will be discussed in #3034

Comment thread examples/jax/ep/ep_moe.py Outdated
kernels = kernels.reshape(ep_size, NLE, *kernels.shape[1:])

@jax.jit
def step(idx, toks, w, lk):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does lk stand for?

Comment thread tests/jax/multi_process_launch_ep.sh Outdated
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
leading = _ep_leading_dims(is_outer)
recv_tokens_aval = jax.core.ShapedArray(leading + (recv_pr, hidden_dim), tok_dtype)
recv_topk_weights_aval = jax.core.ShapedArray(leading + (recv_pr,), jnp.float32)
workspace_aval = jax.core.ShapedArray(topk_idx_aval.shape, jnp.int64)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as above about int64

Comment thread examples/jax/ep/ep_moe.py
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Comment thread transformer_engine/jax/ep.py Outdated
Comment thread transformer_engine/jax/ep.py Outdated
Comment on lines +81 to +82
assert ret == 0, f"ncclGetUniqueId failed with code {ret}"
uid_bytes = bytes(uid_arr)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 assert disabled by -O in ctypes UID path

assert ret == 0 is silently elided when Python runs under the -O optimisation flag (common in production or Numba/Conda environments). If ncclGetUniqueId fails, uid_bytes would be all zeros; the all-gather propagates those zeros to every rank in the EP group, causing ncclCommInitRank to either produce mismatched communicators or hang indefinitely with no diagnostic message.

Suggested change
assert ret == 0, f"ncclGetUniqueId failed with code {ret}"
uid_bytes = bytes(uid_arr)
ret = libnccl.ncclGetUniqueId(ctypes.cast(uid_arr, ctypes.c_void_p))
if ret != 0:
raise RuntimeError(f"ncclGetUniqueId failed with code {ret}")

phu0ngng added 4 commits May 23, 2026 19:36
…em_reloc gating

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…s, MoE example)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants